GT relation aware attn #599

Draft

yliu2-sc wants to merge 8 commits into main from yliu2/gt_relation_aware_attn

Conversation

@yliu2-sc
Collaborator

Scope of work done

Where is the documentation for this feature?: N/A

Did you add automated tests or write a test plan?

Updated Changelog.md? NO

Ready for code review?: NO

self._relation_attention_matrices: Optional[nn.Parameter] = None
if relation_attention_mode == "edge_type_additive":
    self._relation_attention_matrices = nn.Parameter(
        torch.empty(num_relations, num_heads, self._head_dim, self._head_dim)
    )
Collaborator


If we look at HGT's equation 3, the W^(ATT)_{\phi(e)} doesn't have a superscript i, so I think the edge-type-specific transformation is per edge type, not per (edge type, head index). The current implementation has more capacity, but we are not doing an apples-to-apples comparison with HGT.
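To illustrate the difference, here is a minimal sketch (names and sizes are illustrative, not from this PR): the HGT-style parameter is one d×d matrix per edge type shared across heads, while the current implementation keeps one matrix per (edge type, head) pair. The shared matrix can still be broadcast across the head dimension when applied.

```python
import torch
import torch.nn as nn

num_relations, num_heads, head_dim = 4, 8, 16

# HGT eq. 3 style: one matrix per edge type, shared across heads -> (R, d, d)
w_att_hgt = nn.Parameter(torch.empty(num_relations, head_dim, head_dim))
nn.init.xavier_uniform_(w_att_hgt)

# Current implementation: per (edge type, head) -> (R, H, d, d)
w_att_ours = nn.Parameter(torch.empty(num_relations, num_heads, head_dim, head_dim))
nn.init.xavier_uniform_(w_att_ours)

# The shared variant broadcasts over the head axis when transforming keys
# of shape (R, H, n, d):
keys = torch.randn(num_relations, num_heads, 5, head_dim)
transformed = keys @ w_att_hgt[:, None, :, :]  # still (R, H, n, d)
```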

Collaborator


Also, if we compare with HGT's equation 3, it seems we are missing the \mu attention multiplier per (source_type, edge_type, destination_type) triple.
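For concreteness, a hypothetical sketch of what that \mu prior looks like: a learnable per-head scalar indexed by the meta-relation triple, multiplied onto the attention logits. Names and shapes here are illustrative assumptions, not code from this PR.

```python
import torch
import torch.nn as nn

num_meta_relations, num_heads = 6, 8

# One learnable scalar per (meta-relation, head); HGT initializes mu to ones.
mu = nn.Parameter(torch.ones(num_meta_relations, num_heads))

logits = torch.randn(num_heads, 3, 3)   # per-head attention logits
meta_relation_id = 2                    # index of <phi(s), psi(e), phi(t)>

# Scale the logits by the prior for this meta relation (broadcast over keys/queries).
scaled = logits * mu[meta_relation_id][:, None, None]
```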

pairwise_relation_mask=pairwise_relation_mask,
)
if relation_attention_bias is not None:
attn_bias = (
Collaborator


HGT uses KWQ^T, whereas we use KQ^T + KWQ^T. So basically we reparametrized HGT's formula with W'^ATT = I + W^ATT. I think we have the same expressiveness, but maybe we should initialize our W^ATT differently: zero init or a small-sigma gaussian init could be options, as Xavier could make the variance a bit too big for a bias term. I'm open to discussion.
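A small sketch of the two inits under discussion (assuming the reparametrization above): with zero init, the relation term K W Q^T vanishes at step 0, so the reparametrized score K(I + W)Q^T reduces exactly to the plain KQ^T attention; a small-sigma gaussian keeps W a small perturbation of the identity instead.

```python
import torch
import torch.nn as nn

head_dim = 16

# Option 1: zero init -> relation term is a no-op at initialization.
w_att = nn.Parameter(torch.zeros(head_dim, head_dim))

# Option 2: small-sigma gaussian init (sigma chosen illustratively).
w_att_gauss = nn.Parameter(torch.empty(head_dim, head_dim))
nn.init.normal_(w_att_gauss, mean=0.0, std=0.02)

q = torch.randn(3, head_dim)
k = torch.randn(3, head_dim)

# K Q^T + K W Q^T == K (I + W) Q^T; with zero init this equals plain K Q^T.
logits = k @ q.T + (k @ w_att) @ q.T
```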

active_relation_ids = torch.unique(active_relation_positions[:, 3], sorted=True)

for relation_idx_tensor in active_relation_ids:
relation_idx = int(relation_idx_tensor.item())
Collaborator


This .item() call and the torch.unique(active_relation_positions[:, 3], sorted=True) above both cause a GPU sync. Some AI suggestions to avoid it for performance:

# remove the torch.unique line above
for relation_idx in range(self._num_relations):
    # remove the .item() line
    relation_positions = active_relation_positions[
        active_relation_positions[:, 3] == relation_idx
    ]
    ...
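A runnable toy version of the suggestion above (the tensor values are made up for illustration): iterating over the host-known range of relation ids keeps the boolean-mask selection entirely on device, with no unique()/.item() device-to-host sync. Note the selection may yield zero rows for relations absent from the batch, so downstream ops must tolerate empty tensors.

```python
import torch

num_relations = 3  # assumed known on the host (e.g. self._num_relations)

# Toy stand-in for active_relation_positions; column 3 holds the relation id.
active_relation_positions = torch.tensor(
    [[0, 0, 0, 2],
     [1, 0, 0, 0],
     [2, 0, 0, 2]]
)

for relation_idx in range(num_relations):
    # Pure tensor comparison against a Python int: no GPU sync needed.
    relation_positions = active_relation_positions[
        active_relation_positions[:, 3] == relation_idx
    ]
    # relation_positions may be empty for relations not present in the batch.
```

The trade-off is that the loop now runs num_relations iterations even when few relations are active, which is usually cheap relative to a forced sync per step.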
